#!/usr/bin/python3

import argparse
import itertools

from torchvision.utils import save_image
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torch.autograd import Variable
from PIL import Image
import torch

from models import Generator
from models import Discriminator
from models import make_model, extract_feature
from utils import ReplayBuffer
from utils import LambdaLR
from utils import weights_init_normal
from datasets import ImageDataset
import numpy as np
import os

# ---- Command-line interface ----
# Every option has a default, so the script can be launched with no arguments.
parser = argparse.ArgumentParser()

# Training schedule
parser.add_argument(
    '--epoch',
    help='starting epoch',
    default=0,
    type=int,
)
parser.add_argument(
    '--n_epochs',
    help='number of epochs of training',
    default=200,
    type=int,
)
parser.add_argument(
    '--batchSize',
    help='size of the batches',
    default=4,
    type=int,
)

# Data location
parser.add_argument(
    '--dataroot',
    help='root directory of the dataset',
    default='datasets/horse2zebra/',
    type=str,
)

# Optimisation
parser.add_argument(
    '--lr',
    help='initial learning rate',
    default=0.0002,
    type=float,
)
parser.add_argument(
    '--decay_epoch',
    help='epoch to start linearly decaying the learning rate to 0',
    default=100,
    type=int,
)

# Image geometry
parser.add_argument(
    '--size',
    help='size of the data crop (squared assumed)',
    default=256,
    type=int,
)
parser.add_argument(
    '--input_nc',
    help='number of channels of input data',
    default=3,
    type=int,
)
parser.add_argument(
    '--output_nc',
    help='number of channels of output data',
    default=3,
    type=int,
)

# Hardware / loading
parser.add_argument(
    '--cuda',
    help='use GPU computation',
    action='store_true',
)
parser.add_argument(
    '--n_cpu',
    help='number of cpu threads to use during batch generation',
    default=2,
    type=int,
)

# Training phase selection and optional warm start
parser.add_argument(
    '--initialization',
    help='True: initial phase; False: Train the Generator and Discriminator',
    action='store_true',
)
parser.add_argument(
    '--load_model',
    help='The path of the pretrain model',
    default=None,
    type=str,
)

opt = parser.parse_args()

# Remind the user about --cuda when a GPU is present but was not requested.
if torch.cuda.is_available():
    if not opt.cuda:
        print("WARNING: You have a CUDA device, so you should probably run with --cuda")

###### Definition of variables ######
# Networks: the generator maps input_nc -> output_nc channels; the
# discriminator is built for input_nc channels.
netGen = Generator(opt.input_nc, opt.output_nc)
netDis = Discriminator(opt.input_nc)

if opt.load_model is not None:
    # Deserialize onto the CPU so that a checkpoint written during a GPU run
    # can also be restored on a machine without CUDA; the networks are moved
    # to the GPU below when --cuda is given.
    # NOTE(review): checkpoints written by this script also contain pickled
    # scheduler objects, so they are not plain tensor dictionaries. Recent
    # PyTorch releases may need `weights_only=False` here -- confirm against
    # the installed version. Only load checkpoints you trust.
    checkpoint = torch.load(opt.load_model, map_location='cpu')
    netGen.load_state_dict(checkpoint['Gen_state_dict'])
    netDis.load_state_dict(checkpoint['Dis_state_dict'])

if opt.cuda:
    netGen.cuda()
    netDis.cuda()

# Fresh runs (no checkpoint) start from the project's weight initializer.
if opt.load_model is None:
    netGen.apply(weights_init_normal)
    netDis.apply(weights_init_normal)

# Losses
criterion_GAN = torch.nn.MSELoss()
criterion_identity = torch.nn.L1Loss()

# Optimizers & LR schedulers
optimizer_Gen = torch.optim.Adam(netGen.parameters(), lr=opt.lr, betas=(0.5, 0.999))
optimizer_Dis = torch.optim.Adam(netDis.parameters(), lr=opt.lr, betas=(0.5, 0.999))


def _attach_scheduler(scheduler, optimizer):
    """Bind a scheduler restored from a checkpoint to a live optimizer.

    A scheduler object read back from a checkpoint still references the
    optimizer copy that was pickled with it, so stepping it would never change
    the learning rate of the optimizer that actually trains the network. This
    re-points the scheduler at ``optimizer`` and immediately applies the rate
    the scheduler is currently at (``base_lr * lr_lambda(last_epoch)``).

    Returns the same scheduler object for convenience.
    """
    scheduler.optimizer = optimizer
    for group, base_lr, lr_lambda in zip(optimizer.param_groups,
                                         scheduler.base_lrs,
                                         scheduler.lr_lambdas):
        group['lr'] = base_lr * lr_lambda(scheduler.last_epoch)
    return scheduler


if opt.load_model is not None and "initial_checkpoint" not in opt.load_model:
    # Resuming adversarial training: continue the saved decay schedules, but
    # drive the optimizers created above.
    lr_scheduler_Gen = _attach_scheduler(checkpoint['learning_rate_Gen'], optimizer_Gen)
    lr_scheduler_Dis = _attach_scheduler(checkpoint['learning_rate_Dis'], optimizer_Dis)
else:
    # Fresh run, or starting from the initialization-phase checkpoint: build
    # new schedules from the command-line settings.
    lr_scheduler_Gen = torch.optim.lr_scheduler.LambdaLR(optimizer_Gen, lr_lambda=LambdaLR(opt.n_epochs, opt.epoch, opt.decay_epoch).step)
    lr_scheduler_Dis = torch.optim.lr_scheduler.LambdaLR(optimizer_Dis, lr_lambda=LambdaLR(opt.n_epochs, opt.epoch, opt.decay_epoch).step)

# Inputs & targets memory allocation
# Pick the tensor constructor once; every buffer below is created through it.
if opt.cuda:
    Tensor = torch.cuda.FloatTensor
else:
    Tensor = torch.Tensor

# Reusable training buffers, one per image stream, each holding a full batch.
_train_batch_shape = (opt.batchSize, opt.input_nc, opt.size, opt.size)
input_photo = Tensor(*_train_batch_shape)
input_cartoon_blur = Tensor(*_train_batch_shape)
input_cartoon = Tensor(*_train_batch_shape)

# Test-time buffer: a single image per step.
_test_batch_shape = (1, opt.input_nc, opt.size, opt.size)
input_photo_test = Tensor(*_test_batch_shape)

# Constant discriminator targets, one value per sample: 1.0 = real, 0.0 = fake.
_target_shape = (opt.batchSize, 1)
target_real = Variable(Tensor(*_target_shape).fill_(1.0), requires_grad=False)
target_fake = Variable(Tensor(*_target_shape).fill_(0.0), requires_grad=False)

# Buffer of generated images; the discriminator's fake batch is drawn from it.
fake_buffer = ReplayBuffer()

# Dataset loader
# Training augmentation: upscale by 12 %, take a random size x size crop,
# random horizontal flip, then convert to a tensor and normalize each channel
# with mean 0.5 / std 0.5.
transforms_ = [ transforms.Resize(int(opt.size*1.12), Image.BICUBIC), 
                transforms.RandomCrop(opt.size), 
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                transforms.Normalize((0.5,0.5,0.5), (0.5,0.5,0.5)) ]
dataset = ImageDataset(opt.dataroot, transforms_=transforms_, unaligned=True)
# drop_last=True: the training loop copies each batch into pre-allocated
# buffers of exactly opt.batchSize samples and compares predictions against
# targets of the same size, so a final, smaller batch would raise a size
# mismatch. Dropping it keeps every batch full.
dataloader = DataLoader(dataset,
                        batch_size=opt.batchSize, shuffle=True, num_workers=opt.n_cpu,
                        drop_last=True)

# Test images are only converted and normalized (no resize or crop), and are
# served one at a time in a fixed order.
transforms_test_ = [ transforms.ToTensor(),
                transforms.Normalize((0.5,0.5,0.5), (0.5,0.5,0.5)) ]
dataloader_test = DataLoader(ImageDataset(opt.dataroot, transforms_=transforms_test_, mode='test'), 
                        batch_size=1, shuffle=False, num_workers=opt.n_cpu)

# Feature network handed to extract_feature() for the content loss.
vggModel = make_model()
# Selects the phase: True = initialization (content loss only),
# False = adversarial training of generator and discriminator.
Initial_phase = opt.initialization

# Create output dirs if they don't exist.
# A single call creates 'output' and 'output/cartoon_Gen'; exist_ok avoids the
# check-then-create race of testing os.path.exists() first.
os.makedirs('output/cartoon_Gen', exist_ok=True)

###################################
if __name__ == '__main__':
    # Number of epochs (counted from epoch 0) after which the initialization
    # phase writes its checkpoint and stops.
    INIT_EPOCHS = 20

    def _scalar(loss):
        """Return the value of a loss tensor as a NumPy scalar, detached from the graph."""
        return loss.cpu().detach().numpy()

    def _checkpoint_state(epoch):
        """Collect everything that is written to a checkpoint file for `epoch`."""
        return {
            'epoch': epoch,
            'Gen_state_dict': netGen.state_dict(),
            'Dis_state_dict': netDis.state_dict(),
            'learning_rate_Gen': lr_scheduler_Gen,
            'learning_rate_Dis': lr_scheduler_Dis
        }

    ###### Training ######
    for epoch in range(opt.epoch, opt.n_epochs):
        netGen.train()
        loss_G_list = []
        loss_D_list = []
        # Running sums over the epoch, averaged for the end-of-epoch report.
        loss_D_loss_all = 0.0
        loss_D_real_all = 0.0
        loss_D_fake_all = 0.0
        loss_D_blur_all = 0.0
        loss_GAN_cartoon_all = 0.0
        loss_content_all = 0.0
        num_batches = 0
        for i, batch in enumerate(dataloader):
            num_batches = i + 1
            # Set model input (copied into the pre-allocated buffers)
            photo = Variable(input_photo.copy_(batch['input_photo']),requires_grad=True)
            cartoon_blur = Variable(input_cartoon_blur.copy_(batch['input_cartoon_blur']))
            cartoon = Variable(input_cartoon.copy_(batch['input_cartoon']))

            ###### Generators ######
            optimizer_Gen.zero_grad()
            fake_cartoon = netGen(photo)

            # Content loss: L1 distance between the extracted features of the
            # photo and of the generated image, weighted by 2.5.
            photo_content = extract_feature(vggModel, photo)
            fake_content = extract_feature(vggModel, fake_cartoon)

            loss_content = criterion_identity(photo_content, fake_content) * 2.5

            if Initial_phase:
                # Initialization phase: the generator is trained on the
                # content loss alone; the discriminator is not updated.
                loss_G = loss_content
                loss_G.backward()
                optimizer_Gen.step()
                loss_G_value = _scalar(loss_G)
                loss_G_list.append(loss_G_value)
                print("\r i" , i+1 , "/" , len(dataloader) , " epoch:", epoch+1,"/%d" % INIT_EPOCHS ," loss_G:",loss_G_value, end='', flush=True)
            else:
                # GAN loss: the generator wants its output scored as real.
                pred_fake = netDis(fake_cartoon)
                loss_GAN_cartoon = criterion_GAN(pred_fake, target_real)

                # Total generator loss
                loss_G = loss_GAN_cartoon + loss_content
                loss_G.backward()

                optimizer_Gen.step()
                ###################################

                ###### Discriminator ######
                optimizer_Dis.zero_grad()

                # Real loss
                pred_real_cartoon = netDis(cartoon)
                loss_D_real = criterion_GAN(pred_real_cartoon, target_real)

                # Fake loss: the batch comes from the replay buffer and is
                # detached so no gradient reaches the generator.
                fake_cartoon = fake_buffer.push_and_pop(fake_cartoon)
                pred_fake_cartoon = netDis(fake_cartoon.detach())
                loss_D_fake = criterion_GAN(pred_fake_cartoon, target_fake)

                # Blur loss: blurred cartoons are also labelled as fake.
                pred_blur_cartoon = netDis(cartoon_blur)
                loss_D_blur = criterion_GAN(pred_blur_cartoon, target_fake)

                # Total discriminator loss: mean of the three terms
                loss_Dis = (loss_D_real + loss_D_fake + loss_D_blur) * 1/3 
                loss_Dis.backward()

                optimizer_Dis.step()
                ###################################
                loss_G_value = _scalar(loss_G)
                loss_G_list.append(loss_G_value)
                loss_D_value = _scalar(loss_Dis)
                loss_D_list.append(loss_D_value)
                loss_D_loss_all = loss_D_loss_all + loss_D_value 
                loss_D_real_all = loss_D_real_all + _scalar(loss_D_real)
                loss_D_fake_all = loss_D_fake_all + _scalar(loss_D_fake)
                loss_D_blur_all = loss_D_blur_all + _scalar(loss_D_blur)
                loss_GAN_cartoon_all = loss_GAN_cartoon_all + _scalar(loss_GAN_cartoon)
                loss_content_all = loss_content_all + _scalar(loss_content)
                # Show the configured epoch count rather than a fixed number.
                print("\r i" , i+1 , "/" , len(dataloader) , " epoch:", epoch+1,"/%d" % opt.n_epochs ,
                    " loss_G:%.2f loss_D:%.6f" %(loss_G_value, loss_D_value),end='', flush=True)

        if Initial_phase:
            if epoch == INIT_EPOCHS - 1:
                Initial_phase = False
                print("\nThe Initialization phase have done...")
                # Save models checkpoints, then stop: adversarial training is
                # started by a separate run that loads this file.
                torch.save(_checkpoint_state(epoch), './output/initial_checkpoint.pth' )
                break
        else:
            # Average over the number of batches actually processed. The last
            # batch index is one less than that count (and 0 for a single
            # batch), so it must not be used as the divisor.
            denom = float(max(num_batches, 1))
            print("\nepoch:", epoch,' loss_G',np.mean(loss_G_list),' loss_GAN_cartoon', loss_GAN_cartoon_all / denom, ' loss_content', loss_content_all / denom)
            print("epoch:", epoch,' loss_D',loss_D_loss_all / denom,' loss_D_real',loss_D_real_all / denom ,
                ' loss_D_fake',loss_D_fake_all / denom ,' loss_D_blur',loss_D_blur_all / denom , "\n")
            # Update learning rates
            lr_scheduler_Gen.step()
            lr_scheduler_Dis.step()

            # Save models checkpoints after every epoch
            torch.save(_checkpoint_state(epoch), './output/checkpoint%d.pth' % epoch)

            # Render the test photos with the current generator. Gradients are
            # not needed here, so autograd bookkeeping is switched off.
            netGen.eval()
            with torch.no_grad():
                for test_idx, test_batch in enumerate(dataloader_test):
                    # Set model input
                    photo_test = Variable(input_photo_test.copy_(test_batch['input_photo']))

                    # Generate output and map it from [-1, 1] back to [0, 1]
                    fake_cartoon_test = 0.5*(netGen(photo_test).data + 1.0)

                    # Save image files
                    save_image(fake_cartoon_test, './output/cartoon_Gen/epoch_%d_%04d.png' % (epoch,test_idx+1))
        ###################################
    